import btk
import unittest
import _TDDConfigure

class MergeAcquisitionFilterTest(unittest.TestCase):
  """Tests for btk.btkMergeAcquisitionFilter.

  The cases cover a filter without input, a single input, pairs of
  acquisitions built from scratch (with various first frames) and a C3D
  file merged with itself.
  """

  # Weight of the high word when a frame index is stored as two 16-bit words.
  _WORD_BASE = 65536

  def _merge(self, *acquisitions):
    """Return the updated output of a new merge filter.

    The given acquisitions are set as inputs 0, 1, ... in the order they
    are passed; calling this without argument updates a filter that has
    no input at all.
    """
    merger = btk.btkMergeAcquisitionFilter()
    for index, acquisition in enumerate(acquisitions):
      merger.SetInput(index, acquisition)
    merger.Update()
    return merger.GetOutput()

  def _pointLabels(self, acquisition):
    """Return the labels of every point of `acquisition`, in iteration order."""
    labels = []
    it = acquisition.BeginPoint()
    while it != acquisition.EndPoint():
      labels.append(it.value().GetLabel())
      it.incr()
    return labels

  @classmethod
  def _frameFromWords(cls, words):
    """Combine a (low word, high word) pair into a single frame index."""
    return words[0] + words[1] * cls._WORD_BASE

  def test_NoInput(self):
    """A filter updated without input is expected to give an empty output."""
    output = self._merge()
    self.assertEqual(output.GetPointFrameNumber(), 0)
    self.assertEqual(output.GetAnalogFrameNumber(), 0)
    self.assertEqual(output.GetPointNumber(), 0)
    self.assertEqual(output.GetAnalogNumber(), 0)

  def test_OneInput(self):
    """With a single input, the output is expected to mirror that input."""
    source = btk.btkAcquisition()
    source.Init(6, 50, 2, 2)
    source.SetPointFrequency(100)
    output = self._merge(source)
    self.assertEqual(output.GetPointFrequency(), 100)
    self.assertEqual(output.GetAnalogFrequency(), 200)
    self.assertEqual(output.GetPointNumber(), 6)
    self.assertEqual(output.GetAnalogNumber(), 2)
    self.assertEqual(output.GetPointFrameNumber(), 50)
    self.assertEqual(output.GetAnalogFrameNumber(), 100)
    # The individual point and analog channels must carry the same number
    # of frames as reported by the acquisition itself.
    self.assertEqual(output.GetPoint(0).GetFrameNumber(), 50)
    self.assertEqual(output.GetAnalog(0).GetFrameNumber(), 100)

  def test_TwoInputsFromScratch(self):
    """Two inputs sharing the same frames: 12 points are expected.

    The labels of the points coming from the second input are expected
    to carry the suffix '_2'.
    """
    first = btk.btkAcquisition()
    first.Init(6, 50)
    second = btk.btkAcquisition()
    second.Init(6, 50)
    second.SetPointFrequency(100)
    output = self._merge(first, second)
    self.assertEqual(output.GetPointFrequency(), 100)
    self.assertEqual(output.GetAnalogFrequency(), 100)
    self.assertEqual(output.GetPointNumber(), 12)
    self.assertEqual(output.GetAnalogNumber(), 0)
    self.assertEqual(output.GetPointFrameNumber(), 50)
    self.assertEqual(output.GetAnalogFrameNumber(), 50)
    baseLabels = ['uname*' + str(n) for n in range(1, 7)]
    suffixedLabels = [label + '_2' for label in baseLabels]
    self.assertEqual(self._pointLabels(output), baseLabels + suffixedLabels)

  def test_TwoInputsFromScratch_FirstFrame1(self):
    """Second input starting at frame 25: 12 points over 74 frames expected."""
    first = btk.btkAcquisition()
    first.Init(6, 50)
    second = btk.btkAcquisition()
    second.Init(6, 50)
    second.SetFirstFrame(25)
    second.SetPointFrequency(100)
    output = self._merge(first, second)
    self.assertEqual(output.GetPointFrequency(), 100)
    self.assertEqual(output.GetAnalogFrequency(), 100)
    self.assertEqual(output.GetFirstFrame(), 1)
    self.assertEqual(output.GetPointNumber(), 12)
    self.assertEqual(output.GetAnalogNumber(), 0)
    self.assertEqual(output.GetPointFrameNumber(), 74)
    self.assertEqual(output.GetAnalogFrameNumber(), 74)

  def test_TwoInputsFromScratch_FirstFrame2(self):
    """First input starting at frame 25: 12 points over 74 frames expected."""
    first = btk.btkAcquisition()
    first.Init(6, 50)
    first.SetFirstFrame(25)
    second = btk.btkAcquisition()
    second.Init(6, 50)
    second.SetPointFrequency(100)
    output = self._merge(first, second)
    self.assertEqual(output.GetPointFrequency(), 100)
    self.assertEqual(output.GetAnalogFrequency(), 100)
    self.assertEqual(output.GetFirstFrame(), 1)
    self.assertEqual(output.GetPointNumber(), 12)
    self.assertEqual(output.GetAnalogNumber(), 0)
    self.assertEqual(output.GetPointFrameNumber(), 74)
    self.assertEqual(output.GetAnalogFrameNumber(), 74)

  def test_TwoInputsFromScratch_Merging1(self):
    """First input starting at frame 51: 6 points over 100 frames expected.

    The labels are expected to be kept without any suffix.
    """
    first = btk.btkAcquisition()
    first.Init(6, 50)
    first.SetFirstFrame(51)
    second = btk.btkAcquisition()
    second.Init(6, 50)
    second.SetPointFrequency(100)
    output = self._merge(first, second)
    self.assertEqual(output.GetPointFrequency(), 100)
    self.assertEqual(output.GetAnalogFrequency(), 100)
    self.assertEqual(output.GetFirstFrame(), 1)
    self.assertEqual(output.GetPointNumber(), 6)
    self.assertEqual(output.GetAnalogNumber(), 0)
    self.assertEqual(output.GetPointFrameNumber(), 100)
    self.assertEqual(output.GetAnalogFrameNumber(), 100)
    expectedLabels = ['uname*' + str(n) for n in range(1, 7)]
    self.assertEqual(self._pointLabels(output), expectedLabels)

  def test_TwoInputsFromScratch_Merging2(self):
    """Second input starting at frame 51: 6 points over 100 frames expected.

    The labels are expected to be kept without any suffix.
    """
    first = btk.btkAcquisition()
    first.Init(6, 50)
    second = btk.btkAcquisition()
    second.Init(6, 50)
    second.SetPointFrequency(100)
    second.SetFirstFrame(51)
    output = self._merge(first, second)
    self.assertEqual(output.GetPointFrequency(), 100)
    self.assertEqual(output.GetAnalogFrequency(), 100)
    self.assertEqual(output.GetFirstFrame(), 1)
    self.assertEqual(output.GetPointNumber(), 6)
    self.assertEqual(output.GetAnalogNumber(), 0)
    self.assertEqual(output.GetPointFrameNumber(), 100)
    self.assertEqual(output.GetAnalogFrameNumber(), 100)
    expectedLabels = ['uname*' + str(n) for n in range(1, 7)]
    self.assertEqual(self._pointLabels(output), expectedLabels)

  def test_TwinsFromFile_Concat(self):
    """Merge a C3D file with itself and check the concatenated result.

    Points, analog channels and the FORCE_PLATFORM metadata are expected
    to be doubled, the second copy of each label carrying the suffix '_2'.
    """
    reader = btk.btkAcquisitionFileReader()
    reader.SetFilename(_TDDConfigure.C3DFilePathIN + 'sample09/PlugInC3D.c3d')
    source = reader.GetOutput()
    output = self._merge(source, source)

    # General properties.
    self.assertEqual(output.GetPointFrequency(), source.GetPointFrequency())
    self.assertEqual(output.GetAnalogFrequency(), source.GetAnalogFrequency())
    self.assertEqual(output.GetPointNumber(), source.GetPointNumber() * 2)
    self.assertEqual(output.GetAnalogNumber(), source.GetAnalogNumber() * 2)
    self.assertEqual(output.GetPointFrameNumber(), source.GetPointFrameNumber())
    self.assertEqual(output.GetAnalogFrameNumber(), source.GetAnalogFrameNumber())
    self.assertEqual(output.GetEventNumber(), source.GetEventNumber())

    # Labels: the first half is untouched, the second half is suffixed.
    pointNumber = source.GetPointNumber()
    for i in range(pointNumber):
      label = source.GetPoint(i).GetLabel()
      self.assertEqual(output.GetPoint(i).GetLabel(), label)
      self.assertEqual(output.GetPoint(i + pointNumber).GetLabel(), label + '_2')
    analogNumber = source.GetAnalogNumber()
    for i in range(analogNumber):
      label = source.GetAnalog(i).GetLabel()
      self.assertEqual(output.GetAnalog(i).GetLabel(), label)
      self.assertEqual(output.GetAnalog(i + analogNumber).GetLabel(), label + '_2')

    # FORCE_PLATFORM metadata: each entry must hold the source values twice.
    md = output.GetMetaData()
    fp = md.GetChild('FORCE_PLATFORM')
    sourceFp = source.GetMetaData().GetChild('FORCE_PLATFORM')
    used = fp.GetChild('USED').GetInfo().ToInt(0)
    self.assertEqual(used, 4)

    corners = fp.GetChild('CORNERS').GetInfo()
    self.assertEqual(corners.GetDimensions()[2], used)
    self.assertEqual(len(corners.ToDouble()), 48)
    sourceCorners = sourceFp.GetChild('CORNERS').GetInfo().ToDouble()
    for i in range(24):
      self.assertAlmostEqual(corners.ToDouble(i), sourceCorners[i], 5)
      self.assertAlmostEqual(corners.ToDouble(i + 24), sourceCorners[i], 5)

    channel = fp.GetChild('CHANNEL').GetInfo()
    self.assertEqual(channel.GetDimensions()[1], used)
    self.assertEqual(len(channel.ToInt()), 24)
    sourceChannel = sourceFp.GetChild('CHANNEL').GetInfo().ToInt()
    for i in range(12):
      self.assertEqual(channel.ToInt(i), sourceChannel[i])
      # The channel indices of the second copy are expected to be shifted by 12.
      self.assertEqual(channel.ToInt(i + 12), sourceChannel[i] + 12)

    fpType = fp.GetChild('TYPE').GetInfo()
    self.assertEqual(fpType.GetDimensions()[0], used)
    self.assertEqual(len(fpType.ToInt()), 4)
    sourceType = sourceFp.GetChild('TYPE').GetInfo().ToInt()
    for i in range(2):
      self.assertEqual(fpType.ToInt(i), sourceType[i])
      self.assertEqual(fpType.ToInt(i + 2), sourceType[i])

    origin = fp.GetChild('ORIGIN').GetInfo()
    self.assertEqual(origin.GetDimensions()[1], used)
    self.assertEqual(len(origin.ToDouble()), 12)
    sourceOrigin = sourceFp.GetChild('ORIGIN').GetInfo().ToDouble()
    for i in range(6):
      self.assertAlmostEqual(origin.ToDouble(i), sourceOrigin[i], 5)
      self.assertAlmostEqual(origin.ToDouble(i + 6), sourceOrigin[i], 5)

    # POINT must exist and only contain X_SCREEN and Y_SCREEN. A failed
    # assertTrue raises, so no extra guard is needed before dereferencing.
    pointIt = md.FindChild('POINT')
    self.assertTrue(pointIt != md.End())
    pointMd = pointIt.value()
    self.assertEqual(pointMd.GetChildNumber(), 2)
    self.assertTrue(pointMd.FindChild('X_SCREEN') != pointMd.End())
    self.assertTrue(pointMd.FindChild('Y_SCREEN') != pointMd.End())

    # Remaining top-level groups.
    self.assertTrue(md.FindChild('ANALOG') == md.End())
    self.assertTrue(md.FindChild('EVENT') == md.End())
    self.assertTrue(md.FindChild('TRIAL') != md.End())
    self.assertTrue(md.FindChild('SUBJECTS') != md.End())
    self.assertTrue(md.FindChild('SEG') != md.End())
    self.assertTrue(md.FindChild('EVENT_CONTEXT') != md.End())
    self.assertEqual(md.GetChildNumber(), 6)

    # TRIAL:ACTUAL_START_FIELD / ACTUAL_END_FIELD hold the frame index as a
    # pair of 16-bit words (low, high); the high word weighs 2**16 = 65536.
    trial = md.GetChild('TRIAL')
    start = trial.GetChild('ACTUAL_START_FIELD').GetInfo().ToInt()
    stop = trial.GetChild('ACTUAL_END_FIELD').GetInfo().ToInt()
    self.assertEqual(output.GetFirstFrame(), self._frameFromWords(start))
    self.assertEqual(output.GetLastFrame(), self._frameFromWords(stop))
  